from os import path

import taichi as ti

from .texture import getImage
from .common import Vec3, Vec2, toSpherical, mix, Ray3
from ..config import EPS_SKY_BOX

# Public API of this module: the procedural sky functions (``skyLike``,
# ``darkSkyLike``), the ``withSun`` wrapper factory, and the texture-backed
# sky classes (``SkyBox``, ``SkySphere``).
__all__ = [
    "skyLike",
    "darkSkyLike",
    "withSun",
    "SkyBox",
    "SkySphere",
]


@ti.func
def skyLike(ray: Ray3) -> Vec3:
    """The classic blue-and-white gradient sky.

    The normalised vertical component of ``ray.direct`` (in [-1, 1]) is
    remapped to [0, 1] and used to blend from white (value 1) toward the
    light blue ``(0.5, 0.7, 1)`` — assuming ``mix(a, b, k)`` is linear
    interpolation that returns ``a`` at ``k == 0`` and ``b`` at ``k == 1``,
    rays pointing down (-y) are white and rays pointing up (+y) are blue.
    A zero-length direction is treated as horizontal (height 0).
    """
    direction = ray.direct
    length = direction.norm()
    # Guard against a zero-length direction instead of dividing by zero.
    height = direction.y / length if length > 0 else 0
    blend = (height + 1) / 2
    return mix(1, Vec3(0.5, 0.7, 1), blend)


@ti.func
def darkSkyLike(ray: Ray3) -> Vec3:
    """A dimmed version of ``skyLike``.

    Shifts the ``skyLike`` colour down by 0.5 and scales it by a quarter.
    If ``skyLike`` channels lie in [0.5, 1] (as they do when ``mix`` is a
    linear blend of 1 and (0.5, 0.7, 1)), the result lies in [0, 0.125].
    """
    base = skyLike(ray)
    return 0.25 * (base - 0.5)


def withSun(func, light: Vec3, direct: Vec3, angle):
    """Wrap a sky function so that a disc around ``direct`` shows ``light``.

    Args:
        func: Taichi function ``(ray) -> Vec3`` used for the background sky.
        light: colour returned for rays that fall inside the sun disc.
        direct: direction of the sun. It is normalised here, so its length
            does not matter.
        angle: threshold compared directly against the dot product of the
            normalised ray direction and the normalised sun direction. It
            therefore acts as the *cosine* of the sun's angular radius, not
            as an angle in radians. NOTE(review): callers must pass e.g.
            ``cos(radius)``; the parameter name is misleading.

    Returns:
        A ``@ti.func`` taking a ray and returning a ``Vec3``: ``func(ray)``
        when the dot product is strictly less than ``angle``, otherwise
        ``light``.
    """
    sunDirect = Vec3(direct).normalized()

    @ti.func
    def shade(ray) -> Vec3:
        cosToSun = ray.direct.normalized().dot(sunDirect)
        color = Vec3(light)
        # Only fall back to the wrapped sky outside the sun disc.
        if cosToSun < angle:
            color = func(ray)
        return color

    return shade


class SkyBox:
    """Cube-map sky built from six face textures."""

    def __init__(self, fileDir: str) -> None:
        """Load the six sky textures from the folder ``fileDir``.

        The folder must contain ``panorama_0.png`` through ``panorama_5.png``.
        They are assigned, in that order, to the N, E, S, W, U and D faces.
        The images are expected to be equally sized squares; this is not
        checked here.
        """
        fileNames = (
            "panorama_0.png",
            "panorama_1.png",
            "panorama_2.png",
            "panorama_3.png",
            "panorama_4.png",
            "panorama_5.png",
        )
        # Loaded in order: N <- panorama_0, E <- panorama_1, ..., D <- panorama_5.
        self.N, self.E, self.S, self.W, self.U, self.D = [
            getImage(path.join(fileDir, name)) for name in fileNames
        ]

    @ti.func
    def sample(self, ray: Ray3) -> Vec3:
        """Return the sky-box colour seen along ``ray``.

        The direction is divided by its largest absolute component, which puts
        it on the surface of the cube [-1, 1]^3. The face is chosen by the
        component that reached +/-1, checked in x, y, z order:
        +x -> E, -x -> W, +y -> U, -y -> D, +z -> N, -z -> S.
        The other two components are mapped from [-1, 1] to roughly [0, 1],
        with some axes flipped per face, and used as texture coordinates.
        Returns black (``Vec3(0)``) when no face matches. NOTE(review): a
        zero-length direction divides by zero; the resulting NaN comparisons
        are all false, so it presumably falls through to black.
        """
        onCube = ray.direct / abs(ray.direct).max()
        # Pull coordinates slightly toward the face centre, presumably so
        # lookups never land exactly on the texture border.
        inset = 1 - 1e-3
        hitX = abs(abs(onCube.x) - 1) < EPS_SKY_BOX
        hitY = abs(abs(onCube.y) - 1) < EPS_SKY_BOX
        hitZ = abs(abs(onCube.z) - 1) < EPS_SKY_BOX
        rgb = Vec3(0)
        if hitX:
            if onCube.x > 0:
                rgb = self.E.sample(Vec2(1, -1) * onCube.yz * 0.5 * inset + 0.5)
            else:
                rgb = self.W.sample(onCube.yz * 0.5 * inset + 0.5)
        elif hitY:
            if onCube.y > 0:
                rgb = self.U.sample(Vec2(-1, 1) * onCube.zx * 0.5 * inset + 0.5)
            else:
                rgb = self.D.sample(onCube.zx * 0.5 * inset + 0.5)
        elif hitZ:
            if onCube.z > 0:
                rgb = self.N.sample(onCube.yx * 0.5 * inset + 0.5)
            else:
                rgb = self.S.sample(Vec2(1, -1) * onCube.yx * 0.5 * inset + 0.5)
        return rgb


class SkySphere:
    """Sky backed by a single texture addressed via ``toSpherical``."""

    def __init__(self, file: str) -> None:
        """Load the sky texture from the image file at ``file``."""
        image = getImage(file)
        self.texture = image

    @ti.func
    def sample(self, ray: Ray3) -> Vec3:
        """Return the texture colour at the spherical coordinates of ``ray.direct``.

        NOTE(review): presumably ``toSpherical`` yields texture coordinates
        suitable for ``texture.sample`` (e.g. an equirectangular mapping);
        confirm against its definition.
        """
        uv = toSpherical(ray.direct)
        return self.texture.sample(uv)
